热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

句子|词表_NLP新闻主题分类任务

篇首语:本文由编程笔记#小编为大家整理,主要介绍了NLP-新闻主题分类任务相关的知识,希望对你有一定的参考价值。在看黑马的NLP的实践项目AI深度学习自然语言处理NLP零

篇首语:本文由编程笔记#小编为大家整理,主要介绍了NLP-新闻主题分类任务相关的知识,希望对你有一定的参考价值。


在看黑马的NLP的实践项目AI深度学习自然语言处理NLP零基础入门,可能由于版本的原因,完全按照课上的来无法运行,就参考实现了一遍,在这记录一下。

目录

1.用到的包

2.新闻主题分类数据

3.处理数据集

4.构建模型

 5.训练

5.1.generate_batch

5.2.训练 & 验证函数

 5.3.主流程




windows系统,jupyter notebook,torch:1.11.0+cu113


1.用到的包

import torch
import torchtext
import os
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence
import string
import re
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data.dataset import random_split
import time
from torch.utils.data import DataLoader

2.新闻主题分类数据

这边按课程的会报错,去网上查了torchtext.datasets.AG_NEWS,但是奇怪的是,看网上的资料会下载数据,我这边电脑里没有数据,不过代码能读到数据,也就没管数据下载不下来的问题了。

load_data_path = "../data"
if not os.path.isdir(load_data_path):
os.mkdir(load_data_path)

train_dataset, test_dataset = torchtext.datasets.AG_NEWS(
root='../data/', split=('train', 'test'))

看一下数据:一个样本是一个元组,第一个元素是int类型,表示label,第二个是str类型

 数据基本信息:

训练集有120000个样本,标签是共有4个取值:1,2,3,4。各类标签在训练集测试集分布比较均匀。


3.处理数据集

功能:1.将\\替换成空格(即将其两边的单词拆分成两个单词),将所有字母转换成小写。2.将label转换成[0,3]。3.句子长度截取

punct = str.maketrans('','',string.punctuation)
def process_datasets_by_Tokenizer(train_dataset, test_dataset, seq_len=200):
"""
参数:
train_dataset: 训练样本列表list(tuple(int, str))
返回:
train_dataset: 训练集列表list(tuple(tensor, int))
"""
tokenizer = Tokenizer()
train_dataset_texts, train_dataset_labels = [], []
test_dataset_texts, test_dataset_labels = [], []

for label, text in train_dataset:
# 前面的打印可以看到,存在\\\\这种,这边替换成空格,并所有的均为小写字母
train_dataset_texts.append(text.replace('\\\\',' ').translate(punct).lower())
train_dataset_labels.append(label - 1) # 将标签映射到[0,3]

for label, text in test_dataset:
test_dataset_texts.append(text.replace('\\\\',' ').translate(punct).lower())
test_dataset_labels.append(label - 1)

# 这边图省事,把训练集测试集合在一起构建词表,这样就不存在未登录词了
all_dataset_texts = train_dataset_texts + test_dataset_texts
all_dataset_labels = train_dataset_labels + test_dataset_labels
tokenizer.fit_on_texts(all_dataset_texts)

# train_dataset_seqs 是一个列表,其中的每一个元素是 将句子由文本表示 变换成 词表中的索引表示的列表
train_dataset_seqs = tokenizer.texts_to_sequences(train_dataset_texts)
test_datase_seqs = tokenizer.texts_to_sequences(test_dataset_texts)
# print(type(train_dataset_seqs), type(train_dataset_seqs[0])) #
# print(train_dataset_seqs)

# 截取前seq_len个,不足后面补0
# train_dataset_seqs是一个tensor,size:(样本数目, seq_len)
train_dataset_seqs = torch.tensor(sequence.pad_sequences(
train_dataset_seqs, seq_len, padding='post'), dtype=torch.int32)
test_datase_seqs = torch.tensor(sequence.pad_sequences(
test_datase_seqs, seq_len, padding='post'), dtype=torch.int32)
# print(type(train_dataset_seqs), type(train_dataset_seqs[0])) #
# print(train_dataset_seqs)

train_dataset = list(zip(train_dataset_seqs, train_dataset_labels))
test_dataset = list(zip(test_datase_seqs, test_dataset_labels))

vocab_size = len(tokenizer.index_word.keys())
num_class = len(set(all_dataset_labels))
return train_dataset, test_dataset, vocab_size, num_class
embed_dim = 16 # 大概9w个词,这边embedding维度射为16
batch_size = 64
seq_len = 50 # 句子长度取50就能覆盖90%以上的样本
train_dataset, test_dataset, vocab_size, num_class = process_datasets_by_Tokenizer(
train_dataset, test_dataset, seq_len=seq_len)
print(train_dataset[:2])
print("vocab_size = , num_class = ".format(vocab_size, num_class))


[(tensor([ 393, 395, 1571, 14750, 100, 54, 1, 838, 23, 23,
41233, 393, 1973, 10474, 3348, 4, 41234, 34, 3999, 763,
295, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
dtype=torch.int32), 2), (tensor([15798, 1041, 824, 1259, 4230, 23, 23, 898, 770, 305,
15798, 87, 90, 21, 3, 4444, 8, 537, 41235, 6,
15799, 1459, 2085, 5, 1, 490, 228, 21, 3877, 2345,
14, 6498, 7, 185, 333, 4, 1, 112, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
dtype=torch.int32), 2)]
vocab_size = 91629, num_class = 4

将注释掉的4条print语句打开,我们测试一下代码:

train = [(1, 'The moon is light'),
(2, 'This is the last rose of summer')]
test = train[:]
train, test, sz, cls = process_datasets_by_Tokenizer(train, test, seq_len=5)
train, test, sz, cls

得到输出:

 分析了一下样本中句子的长度:其中超过90%的句子长度都不超过50,故后续截取50个单词。


4.构建模型

模型结构简单:embedding层 + 平均池化层 + 全连接层

class TextSentiment(nn.Module):
"""文本分类模型"""
def __init__(self, vocab_size, embed_dim, num_class, seq_len):
"""
description: 类的初始化函数
:param vocab_size: 整个语料包含的不同词汇总数
:param embed_dim: 指定词嵌入的维度
:param num_class: 文本分类的类别总数
"""
super(TextSentiment, self).__init__()

self.seq_len = seq_len
self.embed_dim = embed_dim

# 实例化embedding层, sparse=True代表每次对该层求解梯度时, 只更新部分权重.
self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=True)
# 实例化线性层, 参数分别是embed_dim和num_class.
self.fc = nn.Linear(embed_dim, num_class)
# 为各层初始化权重
self.init_weights()

def init_weights(self):
"""初始化权重函数"""
# 指定初始权重的取值范围数
initrange = 0.5
# 各层的权重参数都是初始化为均匀分布
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.weight.data.uniform_(-initrange, initrange)
# 偏置初始化为0
self.fc.bias.data.zero_()

def forward(self, text):
"""
:param text: 文本数值映射后的结果
:return: 与类别数尺寸相同的张量, 用以判断文本类别
"""
# [batch_size, seq_len, embed_dim]
embedded = self.embedding(text)
# [batch_size, embed_dim, seq_len],
# 后续将句子所在的维度做pooling,所以将句子所在维度放到最后面
embedded = embedded.transpose(2, 1) # 句子所在维度由原先的第二维变成第三维
# [batch_size, embed_dim, 1]
embedded = F.avg_pool1d(embedded, kernel_size=self.seq_len)
# [embed_dim, batch_size]
embedded = embedded.squeeze(-1)
# [batch_size, embed_dim]
# 看到torch.nn.CrossEntropyLoss()自带了softmax,所以这边不再套softmax
return self.fc(embedded)

 5.训练

5.1.generate_batch

generate_batch:构建一个批次内的数据,后续作为DataLoader函数的参数传入

def generate_batch(batch):
"""[summary]
Args:
batch ([type]): [description] 由样本张量和对应标签的元祖 组成的 batch_size 大小的列表
[(sample1, label1), (sample2, label2), ..., (samplen, labeln)]
:return 样本张量和标签各自的列表形式(Tensor)
"""
text = [entry[0].reshape(1, -1) for entry in batch]
# print(text)
label = torch.tensor([entry[1] for entry in batch])
text = torch.cat(text, dim=0)

return torch.tensor(text), torch.tensor(label)

我们测试一下这段的效果:

batch = [(torch.tensor([3, 23, 2, 8]), 1), (torch.tensor([3, 45, 21, 6]), 0)]
res = generate_batch(batch)
print(res, res[0].size())

输出:


5.2.训练 & 验证函数

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def run(data, batch_size, model, criterion,
mode='train', optimizer=None, scheduler=None):
total_loss, total_acc = 0., 0.

shuffle = False
if mode == 'train':
shuffle = True
data = DataLoader(data, batch_size=batch_size, shuffle=shuffle,
collate_fn=generate_batch)
for i, (text, label) in enumerate(data):
# text = text.to(device) # gpu版本
# label = label.to(device)
sz = text.size(0)
if mode == 'train':
optimizer.zero_grad()
output = model(text)
loss = criterion(output, label)
# 累计批次平均,参照蓄水池抽样算法
total_loss = i / (i + 1) * total_loss + loss.item() / sz / (i + 1)
loss.backward()
optimizer.step()
# predict = F.softmax(output, dim=-1)
correct_cnt = (output.argmax(1) == label).sum().item()
total_acc = i / (i + 1) * total_acc + correct_cnt / sz / (i + 1)
else:
with torch.no_grad():
output = model(text)
loss = criterion(output, label)
total_loss = i / (i + 1) * total_loss + loss.item() / sz / (i + 1)
# predict = F.softmax(output, dim=-1)
correct_cnt = (output.argmax(1) == label).sum().item()
total_acc = i / (i + 1) * total_acc + correct_cnt / sz / (i + 1)

# if i % 10 == 0:
# print("i: , loss: ".format(i, total_loss))

# 调整优化器学习率
if (scheduler):
scheduler.step()
# print(total_loss, total_acc, total_loss / count, total_acc / count, count)
return total_loss , total_acc

 5.3.主流程

model = TextSentiment(vocab_size + 1, embed_dim, num_class, seq_len)
# model = TextSentiment(vocab_size + 1, embed_dim, num_class, seq_len).to(device) # gpu版本
criterion = torch.nn.CrossEntropyLoss() # 自带了softmax
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.99)
train_len = int(len(train_dataset) * 0.95)
sub_train_, sub_valid_ = random_split(train_dataset,
[train_len, len(train_dataset) - train_len])
n_epochs = 10
for epoch in range(n_epochs):
start_time = time.time()
train_loss, train_acc = run(sub_train_, batch_size, model, criterion,
mode='train', optimizer=optimizer, scheduler=scheduler)

valid_loss, valid_acc = run(sub_train_, batch_size, model, criterion, mode='validation')

secs = int(time.time() - start_time)
mins = secs / 60
secs = secs % 60

print("Epoch: %d" % (epoch + 1),
" | time in %d minutes, %d seconds" % (mins, secs))
print(
f"\\tLoss: train_loss:.4f(train)\\t|\\tAcc: train_acc * 100:.1f%(train)"
)
print(
f"\\tLoss: valid_loss:.4f(valid)\\t|\\tAcc: valid_acc * 100:.1f%(valid)"
)

打印结果如下:



推荐阅读
  • 本文介绍了lua语言中闭包的特性及其在模式匹配、日期处理、编译和模块化等方面的应用。lua中的闭包是严格遵循词法定界的第一类值,函数可以作为变量自由传递,也可以作为参数传递给其他函数。这些特性使得lua语言具有极大的灵活性,为程序开发带来了便利。 ... [详细]
  • 本文分享了一个关于在C#中使用异步代码的问题,作者在控制台中运行时代码正常工作,但在Windows窗体中却无法正常工作。作者尝试搜索局域网上的主机,但在窗体中计数器没有减少。文章提供了相关的代码和解决思路。 ... [详细]
  • 本文介绍了使用Java实现大数乘法的分治算法,包括输入数据的处理、普通大数乘法的结果和Karatsuba大数乘法的结果。通过改变long类型可以适应不同范围的大数乘法计算。 ... [详细]
  • 1,关于死锁的理解死锁,我们可以简单的理解为是两个线程同时使用同一资源,两个线程又得不到相应的资源而造成永无相互等待的情况。 2,模拟死锁背景介绍:我们创建一个朋友 ... [详细]
  • 《数据结构》学习笔记3——串匹配算法性能评估
    本文主要讨论串匹配算法的性能评估,包括模式匹配、字符种类数量、算法复杂度等内容。通过借助C++中的头文件和库,可以实现对串的匹配操作。其中蛮力算法的复杂度为O(m*n),通过随机取出长度为m的子串作为模式P,在文本T中进行匹配,统计平均复杂度。对于成功和失败的匹配分别进行测试,分析其平均复杂度。详情请参考相关学习资源。 ... [详细]
  • Iamtryingtomakeaclassthatwillreadatextfileofnamesintoanarray,thenreturnthatarra ... [详细]
  • Spring源码解密之默认标签的解析方式分析
    本文分析了Spring源码解密中默认标签的解析方式。通过对命名空间的判断,区分默认命名空间和自定义命名空间,并采用不同的解析方式。其中,bean标签的解析最为复杂和重要。 ... [详细]
  • 本文介绍了C#中生成随机数的三种方法,并分析了其中存在的问题。首先介绍了使用Random类生成随机数的默认方法,但在高并发情况下可能会出现重复的情况。接着通过循环生成了一系列随机数,进一步突显了这个问题。文章指出,随机数生成在任何编程语言中都是必备的功能,但Random类生成的随机数并不可靠。最后,提出了需要寻找其他可靠的随机数生成方法的建议。 ... [详细]
  • 本文介绍了如何使用php限制数据库插入的条数并显示每次插入数据库之间的数据数目,以及避免重复提交的方法。同时还介绍了如何限制某一个数据库用户的并发连接数,以及设置数据库的连接数和连接超时时间的方法。最后提供了一些关于浏览器在线用户数和数据库连接数量比例的参考值。 ... [详细]
  • 本文讨论了如何优化解决hdu 1003 java题目的动态规划方法,通过分析加法规则和最大和的性质,提出了一种优化的思路。具体方法是,当从1加到n为负时,即sum(1,n)sum(n,s),可以继续加法计算。同时,还考虑了两种特殊情况:都是负数的情况和有0的情况。最后,通过使用Scanner类来获取输入数据。 ... [详细]
  • 本文介绍了C#中数据集DataSet对象的使用及相关方法详解,包括DataSet对象的概述、与数据关系对象的互联、Rows集合和Columns集合的组成,以及DataSet对象常用的方法之一——Merge方法的使用。通过本文的阅读,读者可以了解到DataSet对象在C#中的重要性和使用方法。 ... [详细]
  • 本文介绍了使用PHP实现断点续传乱序合并文件的方法和源码。由于网络原因,文件需要分割成多个部分发送,因此无法按顺序接收。文章中提供了merge2.php的源码,通过使用shuffle函数打乱文件读取顺序,实现了乱序合并文件的功能。同时,还介绍了filesize、glob、unlink、fopen等相关函数的使用。阅读本文可以了解如何使用PHP实现断点续传乱序合并文件的具体步骤。 ... [详细]
  • 本文讨论了一个关于cuowu类的问题,作者在使用cuowu类时遇到了错误提示和使用AdjustmentListener的问题。文章提供了16个解决方案,并给出了两个可能导致错误的原因。 ... [详细]
  • 本文详细介绍了Linux中进程控制块PCBtask_struct结构体的结构和作用,包括进程状态、进程号、待处理信号、进程地址空间、调度标志、锁深度、基本时间片、调度策略以及内存管理信息等方面的内容。阅读本文可以更加深入地了解Linux进程管理的原理和机制。 ... [详细]
  • 本文介绍了一个在线急等问题解决方法,即如何统计数据库中某个字段下的所有数据,并将结果显示在文本框里。作者提到了自己是一个菜鸟,希望能够得到帮助。作者使用的是ACCESS数据库,并且给出了一个例子,希望得到的结果是560。作者还提到自己已经尝试了使用"select sum(字段2) from 表名"的语句,得到的结果是650,但不知道如何得到560。希望能够得到解决方案。 ... [详细]
author-avatar
邵crnich
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有